/*
 * get_encryption_keys.go
 *
 * This source file is part of the FoundationDB open source project
 *
 * Copyright 2013-2024 Apple Inc. and the FoundationDB project authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// GetEncryptionKeys handler
// Handler is responsible for the following:
// 1. Parse the incoming HttpRequest and validate JSON request structural sanity
// 2. Ability to handle getEncryptionKeys by 'KeyId' or 'DomainId' as requested
// 3. Ability to inject faults if requested

package main

import (
    "encoding/json"
    "fmt"
    "io/ioutil"
    "log"
    "math/rand"
    "net/http"
)

// CipherDetailRes is one cipher entry of a GetEncryptionKeys response.
type CipherDetailRes struct {
	// BaseCipherId is the id of the returned base cipher.
	BaseCipherId uint64 `json:"base_cipher_id"`
	// EncryptDomainId is the encryption domain id the entry belongs to.
	EncryptDomainId int64 `json:"encrypt_domain_id"`
	// BaseCipher holds the cipher bytes converted to a string.
	BaseCipher string `json:"base_cipher"`
}

// ValidationToken is a name/value pair carried in the "validation_tokens"
// array of a GetEncryptionKeys request.
type ValidationToken struct {
	// TokenName is the token's name.
	TokenName string `json:"token_name"`
	// TokenValue is the token's value.
	TokenValue string `json:"token_value"`
}

// CipherDetailReq is one cipher entry of a GetEncryptionKeys request.
type CipherDetailReq struct {
	// BaseCipherId is the requested base cipher id; it is read by the
	// lookup-by-key-id handler.
	BaseCipherId uint64 `json:"base_cipher_id"`
	// EncryptDomainId is the encryption domain id of the entry.
	EncryptDomainId int64 `json:"encrypt_domain_id"`
}

// GetEncryptKeysResponse is the JSON body written for a successful
// GetEncryptionKeys request.
type GetEncryptKeysResponse struct {
	// CipherDetails lists one entry per cipher requested.
	CipherDetails []CipherDetailRes `json:"cipher_key_details"`
	// KmsUrls is filled only when the request set "refresh_kms_urls".
	KmsUrls []string `json:"kms_urls"`
}

// GetEncryptKeysRequest is the JSON body of a GetEncryptionKeys request.
type GetEncryptKeysRequest struct {
	// QueryMode selects the lookup flavour; the dispatcher in this file
	// accepts "lookupByKeyId" and "lookupByDomainId".
	QueryMode string `json:"query_mode"`
	// CipherDetails lists the ciphers being requested.
	CipherDetails []CipherDetailReq `json:"cipher_key_details"`
	// ValidationTokens is decoded from the request; the handlers in this
	// file do not inspect it.
	ValidationTokens []ValidationToken `json:"validation_tokens"`
	// RefreshKmsUrls asks for KMS URLs to be included in the response.
	RefreshKmsUrls bool `json:"refresh_kms_urls"`
}

// cipherMapInstanceSingleton maps a base cipher id to its base cipher bytes.
type cipherMapInstanceSingleton map[uint64][]byte

// Error / fault-injection locations. Each value is a key understood by
// errStrMap and is passed to shouldInjectFault by the handlers in this file.
const (
	// Reading the HTTP request body.
	READ_HTTP_REQUEST_BODY = iota
	// Unmarshaling the request body as generic JSON.
	UNMARSHAL_REQUEST_BODY_JSON
	// The request's query_mode is not one the handler supports.
	UNSUPPORTED_QUERY_MODE
	// Parsing the body into a GetEncryptKeysRequest.
	PARSE_HTTP_REQUEST
	// Marshaling the response to JSON.
	MARSHAL_RESPONSE
)

// Sizing parameters handed to NewCipherMap by initEncryptCipherMap.
const (
	// maxCipherKeys is the maximum number of cipher keys.
	maxCipherKeys uint64 = 1024 * 1024
	// maxCipherSize is the maximum cipher buffer size, in bytes.
	maxCipherSize = 16
)

// cipherMapInstance is the package-level singleton mapping of
// { baseCipherId -> baseCipher }. It is nil until NewCipherMap populates it.
var cipherMapInstance cipherMapInstanceSingleton

// errStrMap returns a lookup function that translates an error location
// (READ_HTTP_REQUEST_BODY, UNMARSHAL_REQUEST_BODY_JSON, ...) into its
// human-readable message. Locations without a message yield "".
func errStrMap() func(int) string {
	return func(location int) string {
		switch location {
		case READ_HTTP_REQUEST_BODY:
			return "Http request body read error"
		case UNMARSHAL_REQUEST_BODY_JSON:
			return "Http request body unmarshal error"
		case UNSUPPORTED_QUERY_MODE:
			return "Unsupported query_mode"
		case PARSE_HTTP_REQUEST:
			return "Error parsing GetEncryptionKeys request"
		case MARSHAL_RESPONSE:
			return "Error marshaling response"
		default:
			return ""
		}
	}
}

// NewCipherMap returns the package-level cipher map, creating it on first use.
//
// While cipherMapInstance is nil, the call builds a map with keys 1..maxKeys,
// each holding cipherSize pseudo-random bytes, and stores it in
// cipherMapInstance. Once the map exists, later calls return it unchanged and
// ignore both arguments.
//
// The function does no locking: the caller is responsible for thread
// synchronization. Recommended to be invoked during package::init().
func NewCipherMap(maxKeys uint64, cipherSize int) cipherMapInstanceSingleton {
	if cipherMapInstance != nil {
		return cipherMapInstance
	}

	// The final size is known up front, so give the map a size hint once
	// instead of letting it grow incrementally.
	cipherMapInstance = make(cipherMapInstanceSingleton, maxKeys)
	for i := uint64(1); i <= maxKeys; i++ {
		cipher := make([]byte, cipherSize)
		// math/rand's Read is documented to always return a nil error.
		rand.Read(cipher)
		cipherMapInstance[i] = cipher
	}
	// Report the number of ciphers actually generated (the maxKeys argument),
	// not the package-level default.
	log.Printf("KMS cipher map populate done, maxCiphers '%d'", maxKeys)
	return cipherMapInstance
}

// getKmsUrls produces between one and five placeholder KMS URLs of the form
// "https://KMS/i:i:i:i" for i = 1, 2, ...; the count is picked at random on
// every call.
func getKmsUrls() (urls []string) {
	for i, total := 1, rand.Intn(5)+1; i <= total; i++ {
		urls = append(urls, fmt.Sprintf("https://KMS/%d:%d:%d:%d", i, i, i, i))
	}
	return urls
}

// isEncryptDomainIdValid reports whether id is acceptable as an encryption
// domain id: any strictly positive value, or one of the two special values
// -1 and -2 (presumably reserved domains; confirm against the FDB
// definitions).
func isEncryptDomainIdValid(id int64) bool {
	return id > 0 || id == -1 || id == -2
}

// abs returns the magnitude of x.
//
// NOTE: negating the minimum int64 value wraps around in Go, so for that one
// input the result is still negative.
func abs(x int64) int64 {
	if x >= 0 {
		return x
	}
	return -x
}

// getBaseCipherIdFromDomainId deterministically maps an encryption domain id
// to a base cipher id computed as 1 + (|domainId| mod maxCipherKeys), so the
// result lies in [1, maxCipherKeys].
func getBaseCipherIdFromDomainId(domainId int64) (baseCipherId uint64) {
	magnitude := uint64(abs(domainId))
	// The modulo is applied before the +1, keeping the id at least 1.
	return 1 + magnitude%maxCipherKeys
}

// getEncryptionKeysByKeyIds serves a GetEncryptionKeys request in
// lookup-by-key-id mode.
//
// byteArr is the raw JSON request body. It is decoded into a
// GetEncryptKeysRequest; for every requested entry the encryption domain id
// is validated and the base cipher id is looked up in cipherMapInstance. On
// success the JSON-encoded GetEncryptKeysResponse is written to w; KMS URLs
// are included only when the request sets refresh_kms_urls. Any failure is
// reported through sendErrorResponse and ends the request.
//
// Faults may be injected at the PARSE_HTTP_REQUEST and MARSHAL_RESPONSE
// locations. shouldInjectFault is consulted at most once per location, and
// only when the real operation succeeded, so the decision that triggers the
// error is the same decision that selects the "[FAULT]" message and a real
// error is never mislabelled as an injected one.
func getEncryptionKeysByKeyIds(w http.ResponseWriter, byteArr []byte) {
	req := GetEncryptKeysRequest{}
	if err := json.Unmarshal(byteArr, &req); err != nil {
		e := fmt.Errorf("%s '%s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}
	if shouldInjectFault(PARSE_HTTP_REQUEST) {
		e := fmt.Errorf("[FAULT] %s '%s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}

	// Deliberately left nil when nothing is requested so the JSON encoding of
	// an empty result is unchanged.
	var details []CipherDetailRes
	for _, reqDetail := range req.CipherDetails {
		encryptDomainId := reqDetail.EncryptDomainId
		if !isEncryptDomainIdValid(encryptDomainId) {
			sendErrorResponse(w, fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId))
			return
		}

		baseCipherId := reqDetail.BaseCipherId
		cipher, found := cipherMapInstance[baseCipherId]
		if !found {
			sendErrorResponse(w, fmt.Errorf("BaseCipherId not found '%d'", baseCipherId))
			return
		}

		// NOTE(review): cipher holds arbitrary bytes and encoding/json coerces
		// strings to valid UTF-8, so the client may not receive the exact
		// bytes; confirm this is acceptable for callers.
		details = append(details, CipherDetailRes{
			BaseCipherId:    baseCipherId,
			EncryptDomainId: encryptDomainId,
			BaseCipher:      string(cipher),
		})
	}

	var urls []string
	if req.RefreshKmsUrls {
		urls = getKmsUrls()
	}

	mResp, err := json.Marshal(GetEncryptKeysResponse{
		CipherDetails: details,
		KmsUrls:       urls,
	})
	if err != nil {
		e := fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}
	if shouldInjectFault(MARSHAL_RESPONSE) {
		e := fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}

	// Write the bytes verbatim. Passing the payload to Fprintf as a format
	// string would corrupt any response that contains a '%' character.
	if _, err := w.Write(mResp); err != nil {
		log.Printf("Error writing response: %v", err)
	}
}

// getEncryptionKeysByDomainIds serves a GetEncryptionKeys request in
// lookup-by-domain-id mode.
//
// byteArr is the raw JSON request body. It is decoded into a
// GetEncryptKeysRequest; for every requested entry the encryption domain id
// is validated, mapped to a base cipher id with getBaseCipherIdFromDomainId,
// and that id is looked up in cipherMapInstance. On success the JSON-encoded
// GetEncryptKeysResponse is written to w; KMS URLs are included only when the
// request sets refresh_kms_urls. Any failure is reported through
// sendErrorResponse and ends the request.
//
// Faults may be injected at the PARSE_HTTP_REQUEST and MARSHAL_RESPONSE
// locations. shouldInjectFault is consulted at most once per location, and
// only when the real operation succeeded, so the decision that triggers the
// error is the same decision that selects the "[FAULT]" message and a real
// error is never mislabelled as an injected one.
func getEncryptionKeysByDomainIds(w http.ResponseWriter, byteArr []byte) {
	req := GetEncryptKeysRequest{}
	if err := json.Unmarshal(byteArr, &req); err != nil {
		e := fmt.Errorf("%s '%s' err '%v'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr), err)
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}
	if shouldInjectFault(PARSE_HTTP_REQUEST) {
		e := fmt.Errorf("[FAULT] %s '%s'", errStrMap()(PARSE_HTTP_REQUEST), string(byteArr))
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}

	// Deliberately left nil when nothing is requested so the JSON encoding of
	// an empty result is unchanged.
	var details []CipherDetailRes
	for _, reqDetail := range req.CipherDetails {
		encryptDomainId := reqDetail.EncryptDomainId
		if !isEncryptDomainIdValid(encryptDomainId) {
			sendErrorResponse(w, fmt.Errorf("EncryptDomainId not valid '%d'", encryptDomainId))
			return
		}

		baseCipherId := getBaseCipherIdFromDomainId(encryptDomainId)
		cipher, found := cipherMapInstance[baseCipherId]
		if !found {
			sendErrorResponse(w, fmt.Errorf("BaseCipherId not found '%d'", baseCipherId))
			return
		}

		// NOTE(review): cipher holds arbitrary bytes and encoding/json coerces
		// strings to valid UTF-8, so the client may not receive the exact
		// bytes; confirm this is acceptable for callers.
		details = append(details, CipherDetailRes{
			BaseCipherId:    baseCipherId,
			EncryptDomainId: encryptDomainId,
			BaseCipher:      string(cipher),
		})
	}

	var urls []string
	if req.RefreshKmsUrls {
		urls = getKmsUrls()
	}

	mResp, err := json.Marshal(GetEncryptKeysResponse{
		CipherDetails: details,
		KmsUrls:       urls,
	})
	if err != nil {
		e := fmt.Errorf("%s err '%v'", errStrMap()(MARSHAL_RESPONSE), err)
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}
	if shouldInjectFault(MARSHAL_RESPONSE) {
		e := fmt.Errorf("[FAULT] %s", errStrMap()(MARSHAL_RESPONSE))
		log.Println(e.Error())
		sendErrorResponse(w, e)
		return
	}

	// Write the bytes verbatim. Passing the payload to Fprintf as a format
	// string would corrupt any response that contains a '%' character.
	if _, err := w.Write(mResp); err != nil {
		log.Printf("Error writing response: %v", err)
	}
}

func handleGetEncryptionKeys(w http.ResponseWriter, r *http.Request) {
    byteArr, err := ioutil.ReadAll(r.Body)
    if err != nil || shouldInjectFault(READ_HTTP_REQUEST_BODY) {
        var e error
        if shouldInjectFault(READ_HTTP_REQUEST_BODY) {
            e = fmt.Errorf("[FAULT] %s", errStrMap()(READ_HTTP_REQUEST_BODY))
        } else {
            e = fmt.Errorf("%s err '%v'", errStrMap()(READ_HTTP_REQUEST_BODY), err)
        }
        log.Println(e.Error())
        sendErrorResponse(w, e)
        return
    }

    var arbitrary_json map[string]interface{}
    err = json.Unmarshal(byteArr, &arbitrary_json)
    if err != nil || shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
        var e error
        if shouldInjectFault(UNMARSHAL_REQUEST_BODY_JSON) {
            e = fmt.Errorf("[FAULT] %s", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON))
        } else {
            e = fmt.Errorf("%s err '%v'", errStrMap()(UNMARSHAL_REQUEST_BODY_JSON), err)
        }
        log.Println(e.Error())
        sendErrorResponse(w, e)
        return
    }

    if shouldInjectFault(UNSUPPORTED_QUERY_MODE) {
        err = fmt.Errorf("[FAULT] %s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
        sendErrorResponse(w, err)
        return
    } else if arbitrary_json["query_mode"] == "lookupByKeyId" {
        getEncryptionKeysByKeyIds(w, byteArr)
    } else if arbitrary_json["query_mode"] == "lookupByDomainId" {
        getEncryptionKeysByDomainIds(w, byteArr)
    } else {
        err = fmt.Errorf("%s '%s'", errStrMap()(UNSUPPORTED_QUERY_MODE), arbitrary_json["query_mode"])
        sendErrorResponse(w, err)
        return
    }
}

// initEncryptCipherMap populates the package-level cipher map through
// NewCipherMap, requesting maxCipherKeys ciphers of maxCipherSize bytes each.
// NewCipherMap only builds the map when cipherMapInstance is still nil.
func initEncryptCipherMap() {
    cipherMapInstance = NewCipherMap(maxCipherKeys, maxCipherSize)
}